from collections import deque
from quant import Order
from quant.const import OrderStatus, OrderType


class SpreadMaker:
    """
    Spot market maker quoting one PostOnly buy and one PostOnly sell order
    around a smoothed mid price.

    Quotes are placed at ``fair - spread`` / ``fair + spread``, where ``fair``
    is the moving average produced by ``self.pricing``. Order sizes are capped
    so that the position relative to the base holding (``hold``) stays within
    ``[-max_pos, +max_pos]``.
    """

    def __init__(self, markets, accounts, config):
        """
        :param markets: market-data hub; must provide ``add_market`` and ``subscribe_all``.
        :param accounts: account factory; must provide ``create_account(exchange, key)``.
        :param config: dict with keys 'exchange', 'symbol', 'amount', 'max_pos',
            'hold', 'spread', 'var', 'price_precision', 'amount_precision', 'key'.
        """
        self.markets = markets

        self.exchange = config['exchange']
        self.symbol = config['symbol']
        self.amount = config['amount']          # nominal order size per side
        self.max_pos = config['max_pos']        # maximum position (either direction) around the base holding
        self.hold = config['hold']              # base holding of the asset
        self.spread = config['spread']          # distance of each quote from the fair price
        self.var = config['var']                # ignore price moves no larger than this
        self.asset = self.symbol.split('/')[0]  # base asset, e.g. 'btc' for the 'btc/usdt' pair

        self.price_prec = config['price_precision']    # decimals used to round prices
        self.amount_prec = config['amount_precision']  # decimals used to round amounts
        self.pricing = AvgCalculator(10)               # moving average used as the fair price

        key = config['key']
        self.acc = accounts.create_account(self.exchange, key)
        self.init_acc()

    def init_acc(self):
        """
        Register the symbol and load the account's balance and open orders.

        Unlike futures, a spot account needs no position or margin query, only
        the balance.
        """
        self.acc.add_symbol(self.symbol)
        self.acc.api.query_balance()
        self.acc.api.query_open_orders()
        self.acc.api.join()  # presumably waits for the queued queries to finish — TODO confirm

    def start(self):
        """Subscribe to the order book of ``symbol`` and route updates to ``on_market_data``."""
        self.markets.add_market('Book', self.exchange, self.symbol)
        self.markets.subscribe_all(self.on_market_data)

    def on_market_data(self, market_data):
        """
        Re-quote both sides on each order-book update.

        A ``None`` book is treated as a disconnect. All orders are cancelled,
        and the price window is cleared so that stale pre-disconnect mids are
        not mixed into the fair price after reconnection.
        """
        book = market_data.data
        if book is None:
            self.pause()
            # Drop pre-disconnect samples; quoting resumes once the window refills.
            self.pricing.deque.clear()
            print('接口已断开，等待自动重连')
            return

        mid = book.get_middle()
        self.pricing.push(mid)
        fair = self.pricing.calculate()
        if fair is None:
            return  # price window not full yet (about 1s after start); do not quote

        buy_amt, sell_amt = self.cal_amt()

        # Round prices and amounts to the configured precision.
        buy_price = round(fair - self.spread, self.price_prec)
        sell_price = round(fair + self.spread, self.price_prec)
        buy_amt = round(buy_amt, self.amount_prec)
        sell_amt = round(sell_amt, self.amount_prec)

        orders_by_side = {}
        for order in self.acc.orders.values():
            orders_by_side.setdefault(order.side, []).append(order)

        quotes = (('buy', buy_price, buy_amt), ('sell', sell_price, sell_amt))
        for side, price, amt in quotes:
            pending = orders_by_side.get(side, [])
            # Only one resting order per side is managed (the last one seen).
            # Cancel any extras so they cannot fill on top of it and push the
            # position past max_pos.
            for extra in pending[:-1]:
                self.try_cancel(extra)
            self.operate(pending[-1] if pending else None, side, price, amt)

    def pause(self):
        """
        Temporarily stop trading by requesting cancellation of every known order.
        """
        # Iterate a snapshot in case cancel_order mutates the orders mapping.
        for o in list(self.acc.orders.values()):
            self.try_cancel(o)

    def cal_amt(self):
        """
        Compute the (unrounded) buy and sell order sizes.

        The position is modelled futures-style as
        ``pos = (free + frozen balance of the asset) - hold``. Each side starts
        at ``amount``. It is reduced so that a fill would keep ``pos`` within
        ``[-max_pos, +max_pos]``. A side's size is zero or negative when the
        position is already at or beyond the limit on that side.

        :returns: ``(buy_amt, sell_amt)``
        """
        amt = self.amount
        hold = self.hold
        max_ = self.max_pos

        bal = self.acc.balance[self.asset]
        bal = bal['free'] + bal['frozen']   # total balance of the base asset
        pos = bal - hold                    # position relative to the base holding

        buy_amt = amt
        sell_amt = amt
        if pos + amt > max_:    # a full buy would push the position above +max_
            buy_amt = max_ - pos
        if pos - amt < -max_:   # a full sell would push the position below -max_
            sell_amt = max_ + pos
        return buy_amt, sell_amt

    def operate(self, pending_order, side, price, amount):
        """
        Reconcile one side's resting order with the desired quote.

        - No resting order: place a PostOnly order if ``amount`` is positive.
        - Resting order and ``amount <= 0`` (position limit reached or exceeded):
          cancel it so a fill cannot push the position further.
        - Resting order whose price differs from ``price`` by more than ``var``:
          cancel it. A replacement is placed on a later update, once the
          cancelled order no longer appears in ``acc.orders``.
        """
        if pending_order is None:
            if amount > 0:
                self.acc.api.place_order(Order(
                    symbol=self.symbol,
                    side=side,
                    price=price,
                    amount=amount,
                    order_type=OrderType.PostOnly,
                ))
        elif amount <= 0:
            self.try_cancel(pending_order)
        elif abs(pending_order.price - price) > self.var:  # price moved too far
            self.try_cancel(pending_order)

    def try_cancel(self, order):
        """
        Request cancellation of ``order`` unless it is still in flight.

        Orders in the Placing or Canceling state are skipped; they are revisited
        on a later update.
        """
        if order.status in (OrderStatus.Placing, OrderStatus.Canceling):
            return  # order is still being routed; leave it for now
        self.acc.api.cancel_order(order)


class AvgCalculator:
    """
    Simple moving average over the most recent ``size`` samples.

    With the default ``size=10`` and data arriving every 100 ms this is
    roughly a 1-second average.
    """
    def __init__(self, size=10):
        """
        :param size: number of samples in the window; must be at least 1.
        :raises ValueError: if ``size`` < 1. A zero-length window would make
            ``calculate`` divide by zero.
        """
        if size < 1:
            raise ValueError('size must be >= 1, got %r' % (size,))
        self.size = size
        self.deque = deque(maxlen=size)  # oldest sample is evicted automatically once full

    def push(self, mid):
        """Append a new sample, evicting the oldest one when the window is full."""
        self.deque.append(mid)

    def calculate(self):
        """
        Return the mean of the window, or ``None`` until ``size`` samples have been pushed.
        """
        if len(self.deque) < self.size:
            return None
        return sum(self.deque) / len(self.deque)
